import pandas as pd
from AuxLibraries import *
import GWPs
from matplotlib import  colors,cm
from Functions.Miscellaneous import utils as ut


class CARBONPRICE():
    def __init__(self):
        """Create the helper object; no instance attributes are set."""
    
    def carbon_price_data(self, dtm):
        """Assemble yearly ETS carbon prices per country and key ``dtm`` for merging.

        Reads the 'Compliance_Price' sheet of
        'local_data/CarbonPriceWB_latest.xlsx', keeps rows whose
        'Instrument Type' is 'ETS' and averages the 2014-2024 columns for each
        'Jurisdiction Covered'.  Jurisdiction names are normalised through
        ``GWPs.get_country_name()``; 'Saitama'/'Tokyo' are folded into 'Japan'
        and 'California'/'RGGI'/'Washington' into the United States entry.

        Prices are additionally joined to the ``SP_GEOGRAPHY -> Country_adj``
        mapping derived from ``dtm`` (with 'EU' renamed to 'Europe' first),
        presumably to fan the EU price out to individual countries.  United
        Kingdom rows obtained through that mapping are kept only for years
        before 2020.

        Args:
            dtm: DataFrame with at least the columns 'SP_GEOGRAPHY',
                'Country_adj' and 'ReportingYear'.  It is modified in place:
                a 'mrg' column ('<Country_adj>-<ReportingYear>') is added.

        Returns:
            Tuple ``(dtm, res)`` where ``res`` has the columns 'mrg', 'year',
            'CarbonPrice' and 'Country_adj'.

        Side effects:
            Calls ``PlotCarbonPrice``, which writes a figure to disk.
        """
        years = list(range(2014, 2025))
        uk_name = 'United Kingdom of Great Britain and Northern Ireland (the)'
        us_name = 'United States of America (the)'

        # Context manager so the workbook handle is released after parsing.
        with pd.ExcelFile('local_data/CarbonPriceWB_latest.xlsx') as workbook:
            z = workbook.parse('Compliance_Price')
        # The first parsed row carries the real column headers.
        z.columns = z.iloc[0]
        z = z.iloc[1:]
        z = z[z['Instrument Type'] == 'ETS']

        # One small frame per jurisdiction, concatenated once at the end.
        frames = []
        for jurisdiction in z['Jurisdiction Covered'].unique():
            rows = z[z['Jurisdiction Covered'] == jurisdiction]
            # '-' marks a missing price in the sheet; average over instruments.
            yearly = rows[years].replace('-', np.nan).mean(axis=0)
            frame = yearly.rename_axis('year').rename('CarbonPrice').reset_index()
            frame['Country'] = jurisdiction
            frame['Region'] = rows.Region.iloc[0]
            frames.append(frame)
        res = pd.concat(frames)
        res['Country'] = res['Country'].replace(GWPs.get_country_name())

        # == Prices reached through the SP_GEOGRAPHY -> Country_adj mapping
        EU_map = (
            dtm[['SP_GEOGRAPHY', 'Country_adj']]
            .groupby('Country_adj').last().reset_index()
            .rename(columns={'SP_GEOGRAPHY': 'Country'})
        )
        EU_map = EU_map[~EU_map.Country_adj.isin(['Indonesia', 'Russian Federation (the)'])]
        I = res.copy()
        I['Country'] = I['Country'].replace('EU', 'Europe')
        I = I.merge(EU_map)
        I = I[['year', 'CarbonPrice', 'Country_adj']].rename(columns={'Country_adj': 'Country'})
        # Keep mapped UK prices only up to 2019.
        uk = I[I.Country == uk_name]
        uk = uk[uk.year < 2020]
        I = pd.concat((I[I.Country != uk_name], uk))

        # == Prices of jurisdictions listed directly in the sheet
        res['Country'] = res['Country'].replace(['Saitama', 'Tokyo'], ['Japan', 'Japan'])
        res['Country'] = res['Country'].replace(['California', 'RGGI', 'Washington'], [us_name] * 3)
        res = res[['Country', 'year', 'CarbonPrice']].groupby(['Country', 'year']).mean().reset_index()

        # NOTE: the comma after 'Switzerland' was missing, which silently merged
        # it with the next literal into 'SwitzerlandSouth Africa' and dropped
        # both countries from the selection.
        res = res[res.Country.isin([
            'Argentina', 'Australia', 'Canada', 'China', 'Mexico',
            'Kazakhstan', 'Korea (the Republic of)', 'New Zealand', 'Switzerland',
            'South Africa', 'Indonesia', uk_name,
            us_name,
        ])]
        res_other = res.copy()

        res = pd.concat((res, I))
        res = res.sort_values(by=['Country', 'year']).rename(columns={'Country': 'Country_adj'})
        res['mrg'] = res['Country_adj'] + '-' + res['year'].astype(int).astype(str)
        res = res[['year', 'CarbonPrice', 'mrg']].groupby('mrg').mean().reset_index()
        # Split on the LAST hyphen so hyphenated country names stay intact.
        res['Country_adj'] = res['mrg'].str.rsplit('-', n=1).str[0]
        res['year'] = res['year'].astype(int)

        # == Same key on the caller's frame so both tables can be merged on 'mrg'
        dtm['mrg'] = dtm['Country_adj'] + '-' + dtm['ReportingYear'].astype(int).astype(str)

        # === Plot carbon price
        self.PlotCarbonPrice(res, res_other)
        return dtm, res

    def PlotCarbonPrice(self, res, res_other):
        """Plot yearly carbon prices per country and save them as a PDF.

        Args:
            res: DataFrame with the columns 'Country_adj', 'year' and
                'CarbonPrice'.  Rows whose 'Country_adj' does not occur in
                ``res_other.Country`` are pooled under the label 'EU'.
            res_other: DataFrame with the columns 'Country', 'year' and
                'CarbonPrice' for the directly listed countries.  The
                Australia/2014 row is excluded from the plot.

        Side effects:
            Opens a new matplotlib figure and writes
            'Figures/FigureCPSI.pdf'.  Neither input frame is modified.
        """
        direct = res_other[~((res_other.Country == 'Australia') & (res_other.year == 2014))]
        # .copy() so the label assignment works on an independent frame rather
        # than on a slice of ``res`` (avoids SettingWithCopyWarning).
        pooled = res[~res.Country_adj.isin(res_other.Country.unique())].copy()
        pooled['Country'] = 'EU'

        C = pd.concat((direct, pooled[['Country', 'year', 'CarbonPrice']])).reset_index(drop=True)
        C['year'] = C['year'].astype(int)
        C = C[C.year < 2024].copy()
        C['Country'] = C['Country'].replace(
            'United Kingdom of Great Britain and Northern Ireland (the)', 'UK')
        C['Country'] = C['Country'].replace('United States of America (the)', 'USA')
        C = C.dropna()

        plt.figure()
        ax = plt.subplot()
        sns.pointplot(data=C, x='year', y='CarbonPrice', hue='Country',
                      ls='-.', palette='Set2', ax=ax)
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        ax.margins(x=0)
        ax.set_xlabel('')
        # Raw string: '\$' is not a valid escape in a normal literal; the
        # backslash is kept so matplotlib renders a literal dollar sign.
        ax.set_ylabel(r'Price of carbon, US\$/tCO2e')
        plt.savefig('Figures/FigureCPSI.pdf', dpi=300, bbox_inches='tight')
    
    
    def DistributionBYSECTOR(self, dtm, Z, res, HORIZON=100):
        """Price reported vs. counterfactual methane emissions and plot them by sector.

        ``dtm`` is merged with the carbon prices in ``res`` and passed through
        ``GWPs.Counterfactual`` twice ('AR5' for 2014-2021, 'AR6' for
        2022-2023).  Each emission value is multiplied by the carbon price to
        give 'Cost based on reported values' and 'Cost based on
        counterfactuals'; 'differential_cost' is reported minus
        counterfactual.

        Args:
            dtm: DataFrame that can be merged with ``res[['CarbonPrice', 'mrg']]``.
            Z: Forwarded unchanged to ``GWPs.Counterfactual``.
            res: DataFrame with at least the columns 'CarbonPrice' and 'mrg'.
            HORIZON: 100 or 20.  Selects the methane factor handed to the
                'AR6' call (28.4 or 81.1) and is forwarded as ``horizon``.

        Returns:
            Tuple ``(X, order_dict)``: the per-observation cost table and a
            dict mapping each sector (ascending mean counterfactual cost) to
            a hex colour.

        Raises:
            ValueError: If ``HORIZON`` is neither 100 nor 20.

        Side effects:
            Prints the horizon, writes 'local_data_for_geo_plot_cost.csv' and
            draws two matplotlib figures (bar chart and box/point plot).
        """
        # Validate up front; previously an unsupported horizon surfaced later
        # as an UnboundLocalError on ``latest_methane``.
        if HORIZON == 100:
            latest_methane = 28.4
        elif HORIZON == 20:
            latest_methane = 81.1
        else:
            raise ValueError(f'HORIZON must be 100 or 20, got {HORIZON!r}')

        S = dtm.merge(res[['CarbonPrice', 'mrg']])
        print('Horizon:', HORIZON)

        # Analysis switches, fixed for now.
        total_emissions = False
        plot_maps = True

        A = GWPs.Counterfactual(S, Z, [2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021], 'AR5', horizon=HORIZON)
        B = GWPs.Counterfactual(S, Z, [2022, 2023], 'AR6', latest_methane, horizon=HORIZON)
        X = pd.concat((A, B))
        # One row per (company, year, gas).
        X['mrg'] = X['cdp_id'].astype(int).astype(str) + X['ReportingYear'].astype(str) + X['Gas']
        X = X.groupby('mrg').last().reset_index(drop=True)

        X = X[~X.main_sector.isin(['Other'])]
        if total_emissions:
            # NOTE(review): this disabled branch drops 'Country_adj', which the
            # CSV export below selects - fix that before enabling it.
            A = X[X.Gas == 'CO2']
            B = X[X.Gas == 'CH4']
            A = A[['ReportingYear', 'cdp_id', 'CarbonPrice', 'Value (tCO2e)', 'main_sector']].rename(columns = {'Value (tCO2e)': 'CO2'})
            A['mrg'] = A['ReportingYear'].astype(int).astype(str)+'-'+A['cdp_id'].astype(int).astype(str)
            B = B[['ReportingYear', 'cdp_id', 'Value (tCO2e)', 'Value (Equalised)']].rename(columns = {'Value (tCO2e)': 'CH4 (R)', 'Value (Equalised)': 'CH4 (C)'})
            B['mrg'] = B['ReportingYear'].astype(int).astype(str)+'-'+B['cdp_id'].astype(int).astype(str)
            X = A.merge(B[['CH4 (R)', 'CH4 (C)', 'mrg']], on = 'mrg')
            X['Value (tCO2e)'] = X['CO2'].astype(float)+X['CH4 (R)'].astype(float)
            X['Value (Equalised)'] = X['CO2'].astype(float)+X['CH4 (C)']
        else:
            # .copy() so the cost columns are added to an independent frame.
            X = X[X.Gas == 'CH4'].copy()

        # ==============================
        reported = 'Cost based on reported values'
        counterfactual = 'Cost based on counterfactuals'
        X['Value (tCO2e)'] = X['Value (tCO2e)'].astype(float)
        X[reported] = X['Value (tCO2e)'] * X['CarbonPrice']
        X[counterfactual] = X['Value (Equalised)'].astype(float) * X['CarbonPrice']
        X['differential_cost'] = -(X[counterfactual] - X[reported])

        if plot_maps:
            X[['differential_cost', 'main_sector', 'Country_adj', 'ReportingYear']].to_csv('local_data_for_geo_plot_cost.csv')

        # === Plot in mean (values in millions)
        by_sector = X[['main_sector', reported, counterfactual]].groupby('main_sector')
        M = by_sector.mean()
        E = by_sector.apply(ut.utils().stdErr)
        M /= 1e6
        E /= 1e6

        order = M.sort_values(by=[counterfactual])
        # plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in
        # Matplotlib 3.9.  At least 10 colour stops keeps the historical
        # palette; more are generated if there are more than 10 sectors.
        cmap = plt.get_cmap('coolwarm')
        cols = [colors.rgb2hex(cmap(i)) for i in np.linspace(0, 1, max(10, len(order)))]
        order_dict = dict(zip(order.index, cols))

        M = M.loc[order.index]
        E = E.loc[order.index]
        bar_ax = M.plot(kind='bar', yerr=E, capsize=4, color=['#ff9900', '#0066ff'])
        bar_ax.set_ylabel('Price that the average company \n pays for methane emissions (MM$)')
        bar_ax.set_xlabel('')

        # === Plot in distribution
        plt.figure()
        K = X[['main_sector', reported, counterfactual]].set_index('main_sector').stack().reset_index()
        K[0] /= 1e6
        K = K.set_index('main_sector')
        K = K.loc[order.index].reset_index()
        sns.boxplot(y='main_sector', x=0, hue='level_1', data=K,
                    showfliers=False, showmeans=False, boxprops=dict(alpha=.5),
                    meanprops={'marker': 'o', 'markerfacecolor': 'k', 'markeredgecolor': 'b', 'markersize': '10'},
                    palette=['#ff9900', '#0066ff'])
        sns.pointplot(y='main_sector', x=0, hue='level_1', data=K, dodge=.4, capsize=0.5,
                      linestyles='', palette=['#ff9900', '#0066ff'], legend=False)
        plt.xlabel('Price to pay for methane emissions (MM$)')
        plt.ylabel('')
        plt.legend(title='')
        return X, order_dict
    
    def PriceGAPCS(self, X, order, ax, HRZ, label_panels=('a', 'b', 'c'), aggregation_method='sum'):
        """Draw the cumulative differential-cost panel with its inset panels on ``ax``.

        Panel A (on ``ax``): stacked area of the cumulative yearly sum of
        'differential_cost' per sector, in millions.
        Panel B (inset): empirical CDF of 'differential_cost' per sector, in
        millions, on a log y-axis.
        Panel C (inset, only when ``HRZ == 100``): number of distinct
        companies per year and sector.

        Args:
            X: DataFrame with at least 'cdp_id', 'ReportingYear',
                'main_sector', 'Cost based on counterfactuals' and
                'Cost based on reported values'.  Not modified.
            order: Mapping sector -> colour; its key order fixes the stacking
                order.  Every key must occur as a sector in ``X``.
            ax: Matplotlib axes that receives panel A and hosts the insets.
            HRZ: Horizon (100 or 20); selects the annotation text and
                position.  For 20 the y-label is blanked.
            label_panels: Panel letters for A, B and C.  The sector legend is
                shown only when the first letter is 'a'.
            aggregation_method: Only affects the y-label ('sum' gives the
                descriptive label, anything else a placeholder); the data are
                always summed.

        Returns:
            The de-duplicated table (one row per company and year, rows
            without a differential cost removed).

        Side effects:
            Prints the number of firms and observations.
        """
        # == Differential cost by sector
        dim = 'main_sector'
        rel_var = 'differential_cost'
        sector_order = list(order.keys())
        cols = list(order.values())

        # === One row per (company, year), eliminating duplicates if any.
        K = X.copy()
        K['mrg'] = K['cdp_id'].astype(int).astype(str) + '-' + K['ReportingYear'].astype(int).astype(str)
        K = K.groupby('mrg').last().reset_index()
        K[rel_var] = -(K['Cost based on counterfactuals'] - K['Cost based on reported values'])
        # 'mrg' is already unique here, so no second de-duplication is needed.
        K = K.dropna(subset=[rel_var]).reset_index(drop=True)

        print('#====================================')
        print('Number of firms:', len(K.cdp_id.unique()))
        print('Number of observations:', len(K))
        print('#====================================')

        M = K[['ReportingYear', dim, rel_var]].groupby(['ReportingYear', dim]).sum()
        M /= 1e6
        M = M.reset_index().pivot(index='ReportingYear', columns=dim, values=rel_var)
        M = M[sector_order].cumsum()

        # ==== Panel A
        # All styling goes through ``ax`` (not the pyplot current axes), so
        # the panel is correct even when ``ax`` is not the active axes.
        M.plot(kind='area', stacked=True, lw=2, color=cols, ax=ax, legend=False)
        ax.margins(x=0)
        if aggregation_method == 'sum':
            ax.set_ylabel('Economic impact, cumulative $MM', fontsize=24)
        else:
            ax.set_ylabel('-------------')
        ax.set_xlabel('')

        # Black outlines on top of the areas; empty labels keep them out of
        # the legend.
        M.columns = [''] * len(M.columns)
        M.plot(stacked=True, lw=2, c='k', ax=ax, legend=False)
        ax.set_xlabel('')
        ax.spines[['right', 'top']].set_visible(False)

        props = dict(boxstyle='round', facecolor='#7ac2e0', alpha=0.5)
        if HRZ == 100:
            ax.text(2014.5, -500, 'Counterfactual = GWP$_{'+str(HRZ)+'}$', bbox=props)
        if HRZ == 20:
            ax.text(2014.5, -10000, 'Counterfactual = GWP$_{'+str(HRZ)+'}$', bbox=props)
            ax.set_ylabel('')
        ax.text(-0.05, 1.05, label_panels[0], transform=ax.transAxes,
                fontsize=24, fontweight='bold', va='top', ha='right')

        if label_panels[0] == 'a':
            handles, labels = ax.get_legend_handles_labels()
            ax.legend(handles[::-1], labels[::-1], loc='center', bbox_to_anchor=(0.5, 1.2), ncol=4)
        else:
            existing_legend = ax.get_legend()
            if existing_legend is not None:
                existing_legend.remove()

        # ==== Panel B
        ax2 = ax.inset_axes([.08, 0.14, 0.3, 0.4])
        for s in sector_order:
            tmp = K[K.main_sector == s]
            sns.ecdfplot(data=tmp[rel_var] / 1e6, color=order[s], ax=ax2,
                         legend=True, linewidth=3)
        # Axis formatting does not depend on the sector: apply it once.
        ax2.ticklabel_format(axis='x', style='plain', scilimits=(7, 7))
        ax2.set_yscale('log')
        ax2.set_xlabel('Economic impact, $MM', fontsize=18)
        ax2.set_ylabel('')
        ax2.spines[['right', 'top']].set_visible(False)
        ax2.text(-0.05, 1.25, label_panels[1], transform=ax2.transAxes,
                 fontsize=24, fontweight='bold', va='top', ha='right')

        # ==== Panel C
        if HRZ == 100:
            ax0 = ax.inset_axes([.5, 0.1, 0.3, 0.4])

            z = K[K.main_sector.isin(sector_order)]
            z = (z[['ReportingYear', 'main_sector', 'cdp_id']]
                 .groupby(['ReportingYear', 'main_sector']).nunique().reset_index())
            z = z.pivot(index='ReportingYear', columns='main_sector', values='cdp_id')
            z = z[sector_order]
            z.index = z.index.astype(str)
            z.plot(kind='area', stacked=True, ax=ax0, color=cols, legend=False)
            ax0.margins(x=0)
            ax0.set_xlabel('')
            ax0.set_ylabel('# of companies')
            ax0.spines[['right', 'top']].set_visible(False)
            ax0.text(-0.05, 1.25, label_panels[2], transform=ax0.transAxes,
                     fontsize=24, fontweight='bold', va='top', ha='right')
        return K
    def PriceGAP_BYREGION(self, data_type, panel_label, ax2, H, ordered=None, legend=True):
        """Plot the empirical CDF of 'differential_cost' per world region on ``ax2``.

        Country names are normalised through ``GWPs.get_country_name()`` and
        mapped to regions using 'local_data/country_iso.xlsx' and
        'local_data/iso_region.xlsx'.  'Australia and New Zealand',
        'Central Asia' and 'Eastern Asia' are pooled as 'Asia-Pacific'; only
        'Northern America', 'Northern Europe', 'Southern Europe',
        'Western Europe' and 'Asia-Pacific' are kept.

        Args:
            data_type: DataFrame with at least 'differential_cost', 'Country'
                and 'Country_adj'.  A copy is used; the caller's frame is not
                modified.
            panel_label: Letter drawn at the top-left corner of the panel.
            ax2: Matplotlib axes to draw on.
            H: Horizon; 100 selects the GWP100 annotation and the two-item
                return value, anything else the GWP20 variant.
            ordered: Optional DataFrame whose index gives the regions and
                their drawing order.  When omitted or empty, regions are
                ordered by descending 1 % quantile of 'differential_cost'.
            legend: Whether to draw the region legend.

        Returns:
            ``(data, ordered)`` when ``H == 100``, otherwise only ``data``
            (the merged, region-filtered table).
        """
        # Region lookup: read and clean once, reuse for both merges.
        isor = pd.read_excel('local_data/iso_region.xlsx')
        isor['Region'] = isor['Region'].apply(lambda x: x.split('+ \xa0')[1])
        ciso = pd.read_excel('local_data/country_iso.xlsx')[['Country', 'ISO-3']]
        ciso = ciso.rename(columns={'ISO-3': 'loc'})
        ciso = isor.merge(ciso)

        # Work on a copy so the caller's frame is left untouched.
        data = data_type.copy()
        data['differential_cost'] = data['differential_cost'].astype(float)
        data['Country'] = data['Country'].replace(GWPs.get_country_name())
        data = data.merge(ciso[['Country', 'loc']].rename(columns={'Country': 'Country_adj'}))
        data = data.merge(isor)
        data['Region'] = data['Region'].replace(
            ['Australia and New Zealand', 'Central Asia', 'Eastern Asia'], ['Asia-Pacific'] * 3)
        data = data[data.Region.isin(['Northern America', 'Northern Europe', 'Southern Europe',
                                      'Western Europe', 'Asia-Pacific'])]

        if ordered is None or len(ordered) == 0:
            ordered = (data[['Region', 'differential_cost']]
                       .groupby('Region').quantile(0.01)
                       .sort_values(by='differential_cost', ascending=False))

        # plt.get_cmap replaces matplotlib.cm.get_cmap (removed in Matplotlib
        # 3.9).  At least 5 colour stops keeps the historical palette; more
        # are generated if the caller supplies a longer ordering.
        cmap = plt.get_cmap('coolwarm')
        cols = [colors.rgb2hex(cmap(i)) for i in np.linspace(0, 1, max(5, len(ordered)))]

        for region, colour in zip(ordered.index, cols):
            tmp = data[data.Region == region]['differential_cost']
            sns.ecdfplot(data=tmp / 1e6, color=colour, ax=ax2, lw=2)
        ax2.ticklabel_format(axis='x', style='plain', scilimits=(7, 7))
        ax2.set_yscale('log')

        # Legend and annotation go through ``ax2`` (not the pyplot current
        # axes) so they land on the panel being drawn.
        if legend:
            ax2.legend(labels=list(ordered.index), loc='center', bbox_to_anchor=(0.5, 1.1))
        ax2.set_xlabel('Economic impact, $MM', fontsize=18)
        ax2.set_ylabel('')
        ax2.spines[['right', 'top']].set_visible(False)
        ax2.text(-0.05, 1.25, panel_label, transform=ax2.transAxes,
                 fontsize=24, fontweight='bold', va='top', ha='right')

        props = dict(boxstyle='round', facecolor='#7ac2e0', alpha=0.5)
        if H == 100:
            ax2.text(-35, 0.1, 'Counterfactual = GWP$_{100}$', bbox=props)
            return data, ordered
        ax2.text(-1000, 0.1, 'Counterfactual = GWP$_{20}$', bbox=props)
        return data
    
    
    
    
    
    

    
